// Copyright Epic Games, Inc. All Rights Reserved.

#include "cas.h"

#include "compactcas.h"
#include "filecas.h"

#include <zencore/compactbinary.h>
#include <zencore/compactbinarybuilder.h>
#include <zencore/compactbinaryvalidation.h>
#include <zencore/except.h>
#include <zencore/fmtutils.h>
#include <zencore/logging.h>
#include <zencore/memory.h>
#include <zencore/scopeguard.h>
#include <zencore/string.h>
#include <zencore/testing.h>
#include <zencore/testutils.h>
#include <zencore/thread.h>
#include <zencore/trace.h>
#include <zencore/uid.h>
#include <zencore/workthreadpool.h>
#include <zenstore/cidstore.h>
#include <zenstore/gc.h>
#include <zenstore/scrubcontext.h>
#include <zenutil/workerpools.h>

#include <gsl/gsl-lite.hpp>

#include <filesystem>
#include <functional>
#include <unordered_map>

//////////////////////////////////////////////////////////////////////////

namespace zen {

/**
 * CAS store implementation
 *
 * Uses a basic strategy of splitting payloads by size, to improve ability to reclaim space
 * quickly for unused large chunks and to maintain locality for small chunks which are
 * frequently accessed together.
 *
 * Inserts are routed to one of three backing strategies (tiny, small, large) by comparing
 * the payload size against the thresholds in the store configuration. Lookups and
 * maintenance operations consult all three strategies.
 *
 * NOTE(review): m_Config and m_LastScrubTime are used by the member functions but are not
 * declared in this class - presumably they are inherited from CasStore; confirm in cas.h.
 */
class CasImpl : public CasStore
{
public:
	CasImpl(GcManager& Gc);
	virtual ~CasImpl();

	// CasStore interface - see the individual definitions for details
	virtual void					  Initialize(const CidStoreConfiguration& InConfig) override;
	virtual CasStore::InsertResult	  InsertChunk(IoBuffer Chunk, const IoHash& ChunkHash, InsertMode Mode) override;
	virtual std::vector<InsertResult> InsertChunks(std::span<IoBuffer> Data,
												   std::span<IoHash>   ChunkHashes,
												   InsertMode		   Mode = InsertMode::kMayBeMovedInPlace) override;
	virtual IoBuffer				  FindChunk(const IoHash& ChunkHash) override;
	virtual bool					  ContainsChunk(const IoHash& ChunkHash) override;
	virtual void					  FilterChunks(HashKeySet& InOutChunks) override;
	virtual bool					  IterateChunks(std::span<IoHash>												  DecompressedIds,
													const std::function<bool(size_t Index, const IoBuffer& Payload)>& AsyncCallback,
													WorkerThreadPool*												  OptionalWorkerPool) override;
	virtual void					  Flush() override;
	virtual void					  ScrubStorage(ScrubContext& Ctx) override;
	virtual void					  GarbageCollect(GcContext& GcCtx) override;
	virtual CidStoreSize			  TotalSize() const override;

private:
	// Payloads smaller than the configured TinyValueThreshold (initialized with the "tobs" name)
	CasContainerStrategy m_TinyStrategy;
	// Payloads smaller than the configured HugeValueThreshold (initialized with the "sobs" name)
	CasContainerStrategy m_SmallStrategy;
	// Everything else; the only strategy which receives the caller's InsertMode
	FileCasStrategy		 m_LargeStrategy;
	// Manifest loaded from, or written to, the ".ucas_root" file in the root directory
	CbObject			 m_ManifestObject;

	enum class StorageScheme
	{
		Legacy		   = 0,
		WithCbManifest = 1
	};

	// Not read or written by any of the member functions defined in this file
	StorageScheme m_StorageScheme = StorageScheme::Legacy;

	// Returns true if no manifest file existed, i.e. the store is considered new
	bool OpenOrCreateManifest();
	// Writes m_ManifestObject to disk, generating a fresh one first if it is empty
	void UpdateManifest();
};

//////////////////////////////////////////////////////////////////////////

// Every backing strategy is handed the same garbage collection manager on construction.
// No storage is touched here; that happens in Initialize().
CasImpl::CasImpl(GcManager& Gc)
: m_TinyStrategy(Gc)
, m_SmallStrategy(Gc)
, m_LargeStrategy(Gc)
{
}

// Only logs which pool is going away; the strategies and the manifest object are
// released by their own destructors.
CasImpl::~CasImpl()
{
	ZEN_INFO("closing Cas pool at '{}'", m_Config.RootDirectory);
}

void
CasImpl::Initialize(const CidStoreConfiguration& InConfig)
{
	ZEN_TRACE_CPU("CAS::Initialize");

	m_Config = InConfig;

	ZEN_INFO("initializing CAS pool at '{}'", m_Config.RootDirectory);

	// Ensure root directory exists - create if it doesn't exist already

	std::filesystem::create_directories(m_Config.RootDirectory);

	// Open or create manifest

	const bool IsNewStore = OpenOrCreateManifest();

	// Initialize payload storage
	{
		WorkerThreadPool&			   WorkerPool = GetMediumWorkerPool();
		std::vector<std::future<void>> Work;
		Work.emplace_back(
			WorkerPool.EnqueueTask(std::packaged_task<void()>{[&]() { m_LargeStrategy.Initialize(m_Config.RootDirectory, IsNewStore); }}));
		Work.emplace_back(WorkerPool.EnqueueTask(std::packaged_task<void()>{[&]() {
			m_TinyStrategy.Initialize(m_Config.RootDirectory, "tobs", 1u << 28, 16, IsNewStore);  // 256 Mb per block
		}}));
		Work.emplace_back(WorkerPool.EnqueueTask(std::packaged_task<void()>{[&]() {
			m_SmallStrategy.Initialize(m_Config.RootDirectory, "sobs", 1u << 30, 4096, IsNewStore);	 // 1 Gb per block
		}}));
		for (std::future<void>& Result : Work)
		{
			if (Result.valid())
			{
				Result.wait();
			}
		}
		for (std::future<void>& Result : Work)
		{
			Result.get();
		}
	}
}

bool
CasImpl::OpenOrCreateManifest()
{
	ZEN_TRACE_CPU("CAS::OpenOrCreateManifest");
	bool IsNewStore = false;

	std::filesystem::path ManifestPath = m_Config.RootDirectory;
	ManifestPath /= ".ucas_root";

	std::error_code Ec;
	BasicFile		ManifestFile;
	ManifestFile.Open(ManifestPath.c_str(), BasicFile::Mode::kRead, Ec);

	bool ManifestIsOk = false;

	if (Ec)
	{
		if (Ec == std::errc::no_such_file_or_directory)
		{
			IsNewStore = true;
		}
	}
	else
	{
		IoBuffer ManifestBuffer = ManifestFile.ReadAll();
		ManifestFile.Close();

		if (ManifestBuffer.Size() > 0 && ManifestBuffer.Data<uint8_t>()[0] == '#')
		{
			// Old-style manifest, does not contain any useful information, so we may as well update it
		}
		else
		{
			CbObject		Manifest{SharedBuffer(ManifestBuffer)};
			CbValidateError ValidationResult = ValidateCompactBinary(ManifestBuffer, CbValidateMode::All);

			if (ValidationResult == CbValidateError::None)
			{
				if (Manifest["id"])
				{
					ManifestIsOk = true;
				}
			}
			else
			{
				ZEN_WARN("Store manifest validation failed: {:#x}, will generate new manifest to recover", uint32_t(ValidationResult));
			}

			if (ManifestIsOk)
			{
				m_ManifestObject = std::move(Manifest);
			}
		}
	}

	if (!ManifestIsOk)
	{
		UpdateManifest();
	}

	return IsNewStore;
}

void
CasImpl::UpdateManifest()
{
	ZEN_TRACE_CPU("CAS::UpdateManifest");
	if (!m_ManifestObject)
	{
		CbObjectWriter Cbo;
		Cbo << "id" << zen::Oid::NewOid() << "created" << DateTime::Now();
		m_ManifestObject = Cbo.Save();
	}

	// Write manifest to file

	std::filesystem::path ManifestPath = m_Config.RootDirectory;
	ManifestPath /= ".ucas_root";

	// This will throw on failure

	ZEN_TRACE("Writing new manifest to '{}'", ManifestPath);

	TemporaryFile::SafeWriteFile(ManifestPath, m_ManifestObject.GetBuffer().GetView());
}

/**
 * Store a single chunk in the strategy matching its size.
 *
 * Sizes below TinyValueThreshold go to the tiny strategy, sizes below HugeValueThreshold
 * go to the small strategy and everything else goes to the large strategy. Mode is only
 * forwarded to the large strategy.
 *
 * @param Chunk		payload to store; a zero sized payload trips an assert on the tiny path
 * @param ChunkHash hash identifying the payload
 * @param Mode		insert mode, used for large payloads only
 * @return the result reported by the strategy which received the chunk
 */
CasStore::InsertResult
CasImpl::InsertChunk(IoBuffer Chunk, const IoHash& ChunkHash, InsertMode Mode)
{
	ZEN_TRACE_CPU("CAS::InsertChunk");

	const uint64_t ChunkSize = Chunk.Size();

	const bool FitsTiny = ChunkSize < m_Config.TinyValueThreshold;
	if (FitsTiny)
	{
		ZEN_ASSERT(ChunkSize);

		return m_TinyStrategy.InsertChunk(Chunk, ChunkHash);
	}

	const bool FitsSmall = ChunkSize < m_Config.HugeValueThreshold;
	if (FitsSmall)
	{
		return m_SmallStrategy.InsertChunk(Chunk, ChunkHash);
	}

	return m_LargeStrategy.InsertChunk(Chunk, ChunkHash, Mode);
}

static void
GetCompactCasResults(CasContainerStrategy&				 Strategy,
					 std::span<IoBuffer>				 Data,
					 std::span<IoHash>					 ChunkHashes,
					 std::span<size_t>					 Indexes,
					 std::vector<CasStore::InsertResult> Results)
{
	const size_t Count = Indexes.size();
	if (Count == 1)
	{
		const size_t Index = Indexes[0];
		Results[Index]	   = Strategy.InsertChunk(Data[Index], ChunkHashes[Index]);
		return;
	}
	std::vector<IoBuffer> Chunks;
	std::vector<IoHash>	  Hashes;
	Chunks.reserve(Count);
	Hashes.reserve(Count);
	for (size_t Index : Indexes)
	{
		Chunks.push_back(Data[Index]);
		Hashes.push_back(ChunkHashes[Index]);
	}

	Strategy.InsertChunks(Chunks, Hashes);

	for (size_t Offset = 0; Offset < Count; Offset++)
	{
		size_t Index   = Indexes[Offset];
		Results[Index] = Results[Offset];
	}
};

static void
GetFileCasResults(FileCasStrategy&					  Strategy,
				  CasStore::InsertMode				  Mode,
				  std::span<IoBuffer>				  Data,
				  std::span<IoHash>					  ChunkHashes,
				  std::span<size_t>					  Indexes,
				  std::vector<CasStore::InsertResult> Results)
{
	for (size_t Index : Indexes)
	{
		Results[Index] = Strategy.InsertChunk(Data[Index], ChunkHashes[Index], Mode);
	}
};

/**
 * Store a batch of chunks, routing each one to the strategy matching its size.
 *
 * The size rules are the same as for InsertChunk(). Chunks are grouped per strategy so
 * that each strategy is handed all of its chunks in one go.
 *
 * @param Data		  payloads to store
 * @param ChunkHashes hashes matching Data element for element (sizes must agree)
 * @param Mode		  insert mode, used for large payloads only
 * @return one result per element of Data, in the same order
 */
std::vector<CasStore::InsertResult>
CasImpl::InsertChunks(std::span<IoBuffer> Data, std::span<IoHash> ChunkHashes, CasStore::InsertMode Mode)
{
	ZEN_TRACE_CPU("CAS::InsertChunks");
	ZEN_ASSERT(Data.size() == ChunkHashes.size());

	const size_t ChunkCount = Data.size();

	// A single chunk does not need any grouping
	if (ChunkCount == 1)
	{
		std::vector<CasStore::InsertResult> Single(1);
		Single.front() = InsertChunk(Data[0], ChunkHashes[0], Mode);
		return Single;
	}

	// Sort the positions of the chunks into one bucket per strategy
	std::vector<size_t> TinyBucket;
	std::vector<size_t> SmallBucket;
	std::vector<size_t> LargeBucket;

	for (size_t Position = 0; Position < ChunkCount; ++Position)
	{
		const uint64_t ChunkSize = Data[Position].Size();

		if (ChunkSize < m_Config.TinyValueThreshold)
		{
			ZEN_ASSERT(ChunkSize);
			TinyBucket.push_back(Position);
			continue;
		}

		if (ChunkSize < m_Config.HugeValueThreshold)
		{
			SmallBucket.push_back(Position);
			continue;
		}

		LargeBucket.push_back(Position);
	}

	// Each helper fills in the entries belonging to its bucket
	std::vector<CasStore::InsertResult> Outcome(ChunkCount);

	if (!TinyBucket.empty())
	{
		GetCompactCasResults(m_TinyStrategy, Data, ChunkHashes, TinyBucket, Outcome);
	}

	if (!SmallBucket.empty())
	{
		GetCompactCasResults(m_SmallStrategy, Data, ChunkHashes, SmallBucket, Outcome);
	}

	if (!LargeBucket.empty())
	{
		GetFileCasResults(m_LargeStrategy, Mode, Data, ChunkHashes, LargeBucket, Outcome);
	}

	return Outcome;
}

/**
 * Look up the payload stored for a hash.
 *
 * The small, tiny and large strategies are asked in that order and the first hit wins.
 * A payload which is found is marked as compressed binary before it is returned.
 *
 * @return the payload, or an empty IoBuffer when no strategy has the chunk
 */
IoBuffer
CasImpl::FindChunk(const IoHash& ChunkHash)
{
	ZEN_TRACE_CPU("CAS::FindChunk");

	IoBuffer Hit = m_SmallStrategy.FindChunk(ChunkHash);

	if (!Hit)
	{
		Hit = m_TinyStrategy.FindChunk(ChunkHash);
	}

	if (!Hit)
	{
		Hit = m_LargeStrategy.FindChunk(ChunkHash);
	}

	if (!Hit)
	{
		// Not found
		return IoBuffer{};
	}

	Hit.SetContentType(ZenContentType::kCompressedBinary);
	return Hit;
}

/**
 * Check whether any strategy holds a chunk for the given hash.
 *
 * The small, tiny and large strategies are asked in that order; the check stops at the
 * first strategy which reports the chunk.
 */
bool
CasImpl::ContainsChunk(const IoHash& ChunkHash)
{
	ZEN_TRACE_CPU("CAS::ContainsChunk");

	if (m_SmallStrategy.HaveChunk(ChunkHash))
	{
		return true;
	}

	if (m_TinyStrategy.HaveChunk(ChunkHash))
	{
		return true;
	}

	return m_LargeStrategy.HaveChunk(ChunkHash);
}

/**
 * Run the given hash set through the FilterChunks() of every strategy in turn.
 *
 * The set is modified in place. The unit test in this file expects hashes of chunks
 * which are present in the store to be gone from the set afterwards.
 */
void
CasImpl::FilterChunks(HashKeySet& InOutChunks)
{
	ZEN_TRACE_CPU("CAS::FilterChunks");
	m_SmallStrategy.FilterChunks(InOutChunks);
	m_TinyStrategy.FilterChunks(InOutChunks);
	m_LargeStrategy.FilterChunks(InOutChunks);
}

/**
 * Hand the payloads of the requested chunks to a callback.
 *
 * The request is forwarded to the small, tiny and large strategies in that order. Each
 * payload is marked as compressed binary before the callback sees it.
 *
 * @param DecompressedIds	 hashes of the chunks to visit
 * @param AsyncCallback		 receives the index associated with the chunk and its payload
 * @param OptionalWorkerPool forwarded unchanged to every strategy
 * @return false as soon as one strategy's IterateChunks() returns false (the remaining
 *         strategies are then skipped), true otherwise
 */
bool
CasImpl::IterateChunks(std::span<IoHash>												 DecompressedIds,
					   const std::function<bool(size_t Index, const IoBuffer& Payload)>& AsyncCallback,
					   WorkerThreadPool*												 OptionalWorkerPool)
{
	ZEN_TRACE_CPU("CAS::IterateChunks");

	// One adapter shared by all three strategies. Payload is const, so the content type
	// is set through a local handle, and it is that tagged handle which is passed on.
	// NOTE(review): whether a copied IoBuffer shares its content type with the original
	// is not visible here; forwarding the tagged handle is correct in either case.
	const auto TagAndForward = [&AsyncCallback](size_t Index, const IoBuffer& Payload) {
		IoBuffer Chunk(Payload);
		Chunk.SetContentType(ZenContentType::kCompressedBinary);
		return AsyncCallback(Index, Chunk);
	};

	if (!m_SmallStrategy.IterateChunks(DecompressedIds, TagAndForward, OptionalWorkerPool))
	{
		return false;
	}
	if (!m_TinyStrategy.IterateChunks(DecompressedIds, TagAndForward, OptionalWorkerPool))
	{
		return false;
	}
	if (!m_LargeStrategy.IterateChunks(DecompressedIds, TagAndForward, OptionalWorkerPool))
	{
		return false;
	}
	return true;
}

// Logs the pool being flushed and then calls Flush() on the small, tiny and large
// strategies, in that order.
void
CasImpl::Flush()
{
	ZEN_TRACE_CPU("CAS::Flush");
	ZEN_INFO("flushing CAS pool at '{}'", m_Config.RootDirectory);
	m_SmallStrategy.Flush();
	m_TinyStrategy.Flush();
	m_LargeStrategy.Flush();
}

/**
 * Scrub the storage of all three strategies.
 *
 * The timestamp of the scrub context is remembered, and a call carrying the same
 * timestamp as the previous one does nothing.
 */
void
CasImpl::ScrubStorage(ScrubContext& Ctx)
{
	ZEN_TRACE_CPU("Cas::ScrubStorage");

	const bool AlreadyScrubbed = (m_LastScrubTime == Ctx.ScrubTimestamp());

	if (!AlreadyScrubbed)
	{
		m_LastScrubTime = Ctx.ScrubTimestamp();

		m_SmallStrategy.ScrubStorage(Ctx);
		m_TinyStrategy.ScrubStorage(Ctx);
		m_LargeStrategy.ScrubStorage(Ctx);
	}
}

// Forwards the garbage collection context to CollectGarbage() of the small, tiny and
// large strategies, in that order.
void
CasImpl::GarbageCollect(GcContext& GcCtx)
{
	ZEN_TRACE_CPU("Cas::GarbageCollect");

	m_SmallStrategy.CollectGarbage(GcCtx);
	m_TinyStrategy.CollectGarbage(GcCtx);
	m_LargeStrategy.CollectGarbage(GcCtx);
}

/**
 * Report the disk usage of the store.
 *
 * @return the DiskSize reported by each strategy, plus the sum of the three
 */
CidStoreSize
CasImpl::TotalSize() const
{
	const uint64_t TinyBytes  = m_TinyStrategy.StorageSize().DiskSize;
	const uint64_t SmallBytes = m_SmallStrategy.StorageSize().DiskSize;
	const uint64_t LargeBytes = m_LargeStrategy.StorageSize().DiskSize;

	const uint64_t AllBytes = TinyBytes + SmallBytes + LargeBytes;

	return {.TinySize = TinyBytes, .SmallSize = SmallBytes, .LargeSize = LargeBytes, .TotalSize = AllBytes};
}

//////////////////////////////////////////////////////////////////////////

// Factory for the CAS store implementation in this file. The returned store has only
// been constructed; the unit test below calls Initialize() on it before any other use.
std::unique_ptr<CasStore>
CreateCasStore(GcManager& Gc)
{
	return std::make_unique<CasImpl>(Gc);
}

//////////////////////////////////////////////////////////////////////////
//
// Testing related code follows...
//

#if ZEN_WITH_TESTS

// Smoke test: create a store in a temporary directory, insert two small chunks and
// verify that they are reported as new, are filtered out of a hash set and can be
// found again.
TEST_CASE("CasStore")
{
	ScopedTemporaryDirectory TempDir;

	CidStoreConfiguration config;
	config.RootDirectory = TempDir.Path();

	GcManager Gc;

	std::unique_ptr<CasStore> Store = CreateCasStore(Gc);
	Store->Initialize(config);

	// Scrubbing a store without any content must be possible
	WorkerThreadPool ThreadPool{1};
	ScrubContext	 Ctx{ThreadPool};
	Store->ScrubStorage(Ctx);

	// First 16 byte chunk
	IoBuffer Value1{16};
	memcpy(Value1.MutableData(), "1234567890123456", 16);
	IoHash				   Hash1   = IoHash::HashBuffer(Value1.Data(), Value1.Size());
	CasStore::InsertResult Result1 = Store->InsertChunk(Value1, Hash1);
	CHECK(Result1.New);

	// Second 16 byte chunk with different content
	IoBuffer Value2{16};
	memcpy(Value2.MutableData(), "ABCDEFGHIJKLMNOP", 16);
	IoHash				   Hash2   = IoHash::HashBuffer(Value2.Data(), Value2.Size());
	CasStore::InsertResult Result2 = Store->InsertChunk(Value2, Hash2);
	CHECK(Result2.New);

	// Both hashes are in the store, so filtering is expected to leave nothing behind
	HashKeySet ChunkSet;
	ChunkSet.AddHashToSet(Hash1);
	ChunkSet.AddHashToSet(Hash2);

	Store->FilterChunks(ChunkSet);
	CHECK(ChunkSet.IsEmpty());

	// Both chunks must be retrievable by hash
	IoBuffer Lookup1 = Store->FindChunk(Hash1);
	CHECK(Lookup1);
	IoBuffer Lookup2 = Store->FindChunk(Hash2);
	CHECK(Lookup2);
}

// Intentionally empty.
// NOTE(review): presumably referenced from elsewhere so that the linker keeps this
// translation unit (and the test case above) - confirm against the test runner setup.
void
CAS_forcelink()
{
}

#endif

}  // namespace zen
